#include "mlir/IR/MLIRContext.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>

namespace mlir::triton::gpu {
namespace {

/// Test fixture for the layout-dump tests. It owns the MLIRContext used by
/// every test and provides helpers to build blocked encodings and to compare
/// a dumped layout string against its expected text.
class DumpLayoutTest : public ::testing::Test {
public:
  /// Makes sure the TritonGPU dialect is loaded into the fixture's context.
  void SetUp() { ctx.getOrLoadDialect<TritonGPUDialect>(); }

  /// Builds a BlockedEncodingAttr in the fixture's context.
  ///
  /// `spt`, `tpw`, `wpb` and `ord` are forwarded, in that order, to
  /// BlockedEncodingAttr::get; `cpg`, `cSplit` and `cOrd` are forwarded to
  /// CTALayoutAttr::get to form the CTA layout argument.
  /// NOTE(review): the abbreviations presumably stand for size-per-thread,
  /// threads-per-warp, warps-per-block, CTAs-per-CGA, CTA split and the two
  /// orders -- confirm against the attribute definitions.
  BlockedEncodingAttr blocked(ArrayRef<unsigned> spt, ArrayRef<unsigned> tpw,
                              ArrayRef<unsigned> wpb, ArrayRef<unsigned> cpg,
                              ArrayRef<unsigned> cSplit, ArrayRef<unsigned> ord,
                              ArrayRef<unsigned> cOrd) {
    CTALayoutAttr ctaLayout = CTALayoutAttr::get(&ctx, cpg, cSplit, cOrd);
    return BlockedEncodingAttr::get(&ctx, spt, tpw, wpb, ord, ctaLayout);
  }

  /// Compares `output` against `refStr`. On mismatch, prints both strings to
  /// llvm::outs() and records a gtest failure. FAIL() returns from this
  /// helper only, so the calling test body keeps running afterwards.
  void assertSameStr(const std::string &refStr, const std::string &output) {
    if (refStr == output)
      return;

    auto &os = llvm::outs();
    os << "RefStr =\n" << refStr << "\n";
    os << "\n";
    os << "Output =\n" << output << "\n";
    FAIL() << "Incorrect output string";
  }

protected:
  /// Context shared by all attributes and types created in a test.
  MLIRContext ctx;
};

// Dumps a rank-1 blocked layout applied to a 4-element i32 tensor and checks
// the text produced by getLayoutStr in both of its modes.
TEST_F(DumpLayoutTest, SimpleBlocked) {
  auto encoding = blocked({1}, {8}, {4}, {1}, {1}, {0}, {0});
  auto i32Ty = IntegerType::get(encoding.getContext(), 32);
  auto tensorTy = RankedTensorType::get({4}, i32Ty, encoding);

  // Mode 1: useHWPointOfView=false. One bracketed row whose entries have the
  // form T<a>:<b> (presumably thread id and register index -- confirm against
  // getLayoutStr). The expected text ends with a newline.
  const std::string expectedTensorView =
      R"([ T0:0| T4:0| T8:0|T12:0|T16:0|T20:0|T24:0|T28:0,  T1:0| T5:0| T9:0|T13:0|T17:0|T21:0|T25:0|T29:0,  T2:0| T6:0|T10:0|T14:0|T18:0|T22:0|T26:0|T30:0,  T3:0| T7:0|T11:0|T15:0|T19:0|T23:0|T27:0|T31:0]
)";
  const std::string actualTensorView =
      getLayoutStr(tensorTy, /*useHWPointOfView=*/false);
  assertSameStr(expectedTensorView, actualTensorView);

  // Mode 2: useHWPointOfView=true. The expected text has one section per
  // warp (Warp0..Warp3), each listing parenthesised tensor indices.
  const std::string expectedHWView =
      R"(Warp0:
(0), (1), (2), (3), (0), (1), (2), (3)
Warp1:
(0), (1), (2), (3), (0), (1), (2), (3)
Warp2:
(0), (1), (2), (3), (0), (1), (2), (3)
Warp3:
(0), (1), (2), (3), (0), (1), (2), (3)
)";
  const std::string actualHWView =
      getLayoutStr(tensorTy, /*useHWPointOfView=*/true);
  assertSameStr(expectedHWView, actualHWView);
}

// Dumps a rank-3 blocked layout applied to an 8x2x16 i32 tensor and checks
// the text produced by getLayoutStr in both of its modes. The raw strings
// below are golden data: they are compared byte-for-byte (including the
// trailing newline), so do not reformat them.
TEST_F(DumpLayoutTest, NDTensor) {
  // Arguments follow the fixture helper's order:
  // blocked(spt, tpw, wpb, cpg, cSplit, ord, cOrd).
  auto blockedLayout = blocked({2, 1, 4}, {2, 2, 2}, {1, 2, 1}, {1, 1, 1},
                               {1, 1, 1}, {2, 1, 0}, {2, 1, 0});
  auto tensorType = RankedTensorType::get(
      {8, 2, 16}, IntegerType::get(blockedLayout.getContext(), 32),
      blockedLayout);
  // Expected dump for useHWPointOfView=false: nested bracketed rows whose
  // entries have the form T<a>:<b>, with several entries joined by '|' in
  // one cell.
  std::string ref =
      R"([[[  T0:0|  T8:0,   T0:1|  T8:1,   T0:2|  T8:2,   T0:3|  T8:3,   T1:0|  T9:0,   T1:1|  T9:1,   T1:2|  T9:2,   T1:3|  T9:3,   T0:8|  T8:8,   T0:9|  T8:9,  T0:10| T8:10,  T0:11| T8:11,   T1:8|  T9:8,   T1:9|  T9:9,  T1:10| T9:10,  T1:11| T9:11]
[    T2:0| T10:0,   T2:1| T10:1,   T2:2| T10:2,   T2:3| T10:3,   T3:0| T11:0,   T3:1| T11:1,   T3:2| T11:2,   T3:3| T11:3,   T2:8| T10:8,   T2:9| T10:9,  T2:10|T10:10,  T2:11|T10:11,   T3:8| T11:8,   T3:9| T11:9,  T3:10|T11:10,  T3:11|T11:11]]
[[   T0:4|  T8:4,   T0:5|  T8:5,   T0:6|  T8:6,   T0:7|  T8:7,   T1:4|  T9:4,   T1:5|  T9:5,   T1:6|  T9:6,   T1:7|  T9:7,  T0:12| T8:12,  T0:13| T8:13,  T0:14| T8:14,  T0:15| T8:15,  T1:12| T9:12,  T1:13| T9:13,  T1:14| T9:14,  T1:15| T9:15]
[    T2:4| T10:4,   T2:5| T10:5,   T2:6| T10:6,   T2:7| T10:7,   T3:4| T11:4,   T3:5| T11:5,   T3:6| T11:6,   T3:7| T11:7,  T2:12|T10:12,  T2:13|T10:13,  T2:14|T10:14,  T2:15|T10:15,  T3:12|T11:12,  T3:13|T11:13,  T3:14|T11:14,  T3:15|T11:15]]
[[   T4:0| T12:0,   T4:1| T12:1,   T4:2| T12:2,   T4:3| T12:3,   T5:0| T13:0,   T5:1| T13:1,   T5:2| T13:2,   T5:3| T13:3,   T4:8| T12:8,   T4:9| T12:9,  T4:10|T12:10,  T4:11|T12:11,   T5:8| T13:8,   T5:9| T13:9,  T5:10|T13:10,  T5:11|T13:11]
[    T6:0| T14:0,   T6:1| T14:1,   T6:2| T14:2,   T6:3| T14:3,   T7:0| T15:0,   T7:1| T15:1,   T7:2| T15:2,   T7:3| T15:3,   T6:8| T14:8,   T6:9| T14:9,  T6:10|T14:10,  T6:11|T14:11,   T7:8| T15:8,   T7:9| T15:9,  T7:10|T15:10,  T7:11|T15:11]]
[[   T4:4| T12:4,   T4:5| T12:5,   T4:6| T12:6,   T4:7| T12:7,   T5:4| T13:4,   T5:5| T13:5,   T5:6| T13:6,   T5:7| T13:7,  T4:12|T12:12,  T4:13|T12:13,  T4:14|T12:14,  T4:15|T12:15,  T5:12|T13:12,  T5:13|T13:13,  T5:14|T13:14,  T5:15|T13:15]
[    T6:4| T14:4,   T6:5| T14:5,   T6:6| T14:6,   T6:7| T14:7,   T7:4| T15:4,   T7:5| T15:5,   T7:6| T15:6,   T7:7| T15:7,  T6:12|T14:12,  T6:13|T14:13,  T6:14|T14:14,  T6:15|T14:15,  T7:12|T15:12,  T7:13|T15:13,  T7:14|T15:14,  T7:15|T15:15]]
[[  T0:16| T8:16,  T0:17| T8:17,  T0:18| T8:18,  T0:19| T8:19,  T1:16| T9:16,  T1:17| T9:17,  T1:18| T9:18,  T1:19| T9:19,  T0:24| T8:24,  T0:25| T8:25,  T0:26| T8:26,  T0:27| T8:27,  T1:24| T9:24,  T1:25| T9:25,  T1:26| T9:26,  T1:27| T9:27]
[   T2:16|T10:16,  T2:17|T10:17,  T2:18|T10:18,  T2:19|T10:19,  T3:16|T11:16,  T3:17|T11:17,  T3:18|T11:18,  T3:19|T11:19,  T2:24|T10:24,  T2:25|T10:25,  T2:26|T10:26,  T2:27|T10:27,  T3:24|T11:24,  T3:25|T11:25,  T3:26|T11:26,  T3:27|T11:27]]
[[  T0:20| T8:20,  T0:21| T8:21,  T0:22| T8:22,  T0:23| T8:23,  T1:20| T9:20,  T1:21| T9:21,  T1:22| T9:22,  T1:23| T9:23,  T0:28| T8:28,  T0:29| T8:29,  T0:30| T8:30,  T0:31| T8:31,  T1:28| T9:28,  T1:29| T9:29,  T1:30| T9:30,  T1:31| T9:31]
[   T2:20|T10:20,  T2:21|T10:21,  T2:22|T10:22,  T2:23|T10:23,  T3:20|T11:20,  T3:21|T11:21,  T3:22|T11:22,  T3:23|T11:23,  T2:28|T10:28,  T2:29|T10:29,  T2:30|T10:30,  T2:31|T10:31,  T3:28|T11:28,  T3:29|T11:29,  T3:30|T11:30,  T3:31|T11:31]]
[[  T4:16|T12:16,  T4:17|T12:17,  T4:18|T12:18,  T4:19|T12:19,  T5:16|T13:16,  T5:17|T13:17,  T5:18|T13:18,  T5:19|T13:19,  T4:24|T12:24,  T4:25|T12:25,  T4:26|T12:26,  T4:27|T12:27,  T5:24|T13:24,  T5:25|T13:25,  T5:26|T13:26,  T5:27|T13:27]
[   T6:16|T14:16,  T6:17|T14:17,  T6:18|T14:18,  T6:19|T14:19,  T7:16|T15:16,  T7:17|T15:17,  T7:18|T15:18,  T7:19|T15:19,  T6:24|T14:24,  T6:25|T14:25,  T6:26|T14:26,  T6:27|T14:27,  T7:24|T15:24,  T7:25|T15:25,  T7:26|T15:26,  T7:27|T15:27]]
[[  T4:20|T12:20,  T4:21|T12:21,  T4:22|T12:22,  T4:23|T12:23,  T5:20|T13:20,  T5:21|T13:21,  T5:22|T13:22,  T5:23|T13:23,  T4:28|T12:28,  T4:29|T12:29,  T4:30|T12:30,  T4:31|T12:31,  T5:28|T13:28,  T5:29|T13:29,  T5:30|T13:30,  T5:31|T13:31]
[   T6:20|T14:20,  T6:21|T14:21,  T6:22|T14:22,  T6:23|T14:23,  T7:20|T15:20,  T7:21|T15:21,  T7:22|T15:22,  T7:23|T15:23,  T6:28|T14:28,  T6:29|T14:29,  T6:30|T14:30,  T6:31|T14:31,  T7:28|T15:28,  T7:29|T15:29,  T7:30|T15:30,  T7:31|T15:31]]]
)";
  std::string layout = getLayoutStr(tensorType, /*useHWPointOfView=*/false);
  assertSameStr(ref, layout);
  // Expected dump for useHWPointOfView=true: one section per warp (Warp0 and
  // Warp1), each made of rows of parenthesised 3-component tensor indices.
  std::string refHWRep =
      R"(Warp0:
(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4)
(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5)
(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6)
(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7)
(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4)
(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5)
(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6)
(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7)
(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12)
(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13)
(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14)
(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15)
(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12)
(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13)
(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14)
(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15)
(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4)
(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5)
(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6)
(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7)
(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4)
(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5)
(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6)
(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7)
(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12)
(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13)
(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14)
(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15)
(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12)
(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13)
(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14)
(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15)
Warp1:
(0,0, 0), (0,0, 4), (0,1, 0), (0,1, 4), (2,0, 0), (2,0, 4), (2,1, 0), (2,1, 4)
(0,0, 1), (0,0, 5), (0,1, 1), (0,1, 5), (2,0, 1), (2,0, 5), (2,1, 1), (2,1, 5)
(0,0, 2), (0,0, 6), (0,1, 2), (0,1, 6), (2,0, 2), (2,0, 6), (2,1, 2), (2,1, 6)
(0,0, 3), (0,0, 7), (0,1, 3), (0,1, 7), (2,0, 3), (2,0, 7), (2,1, 3), (2,1, 7)
(1,0, 0), (1,0, 4), (1,1, 0), (1,1, 4), (3,0, 0), (3,0, 4), (3,1, 0), (3,1, 4)
(1,0, 1), (1,0, 5), (1,1, 1), (1,1, 5), (3,0, 1), (3,0, 5), (3,1, 1), (3,1, 5)
(1,0, 2), (1,0, 6), (1,1, 2), (1,1, 6), (3,0, 2), (3,0, 6), (3,1, 2), (3,1, 6)
(1,0, 3), (1,0, 7), (1,1, 3), (1,1, 7), (3,0, 3), (3,0, 7), (3,1, 3), (3,1, 7)
(0,0, 8), (0,0,12), (0,1, 8), (0,1,12), (2,0, 8), (2,0,12), (2,1, 8), (2,1,12)
(0,0, 9), (0,0,13), (0,1, 9), (0,1,13), (2,0, 9), (2,0,13), (2,1, 9), (2,1,13)
(0,0,10), (0,0,14), (0,1,10), (0,1,14), (2,0,10), (2,0,14), (2,1,10), (2,1,14)
(0,0,11), (0,0,15), (0,1,11), (0,1,15), (2,0,11), (2,0,15), (2,1,11), (2,1,15)
(1,0, 8), (1,0,12), (1,1, 8), (1,1,12), (3,0, 8), (3,0,12), (3,1, 8), (3,1,12)
(1,0, 9), (1,0,13), (1,1, 9), (1,1,13), (3,0, 9), (3,0,13), (3,1, 9), (3,1,13)
(1,0,10), (1,0,14), (1,1,10), (1,1,14), (3,0,10), (3,0,14), (3,1,10), (3,1,14)
(1,0,11), (1,0,15), (1,1,11), (1,1,15), (3,0,11), (3,0,15), (3,1,11), (3,1,15)
(4,0, 0), (4,0, 4), (4,1, 0), (4,1, 4), (6,0, 0), (6,0, 4), (6,1, 0), (6,1, 4)
(4,0, 1), (4,0, 5), (4,1, 1), (4,1, 5), (6,0, 1), (6,0, 5), (6,1, 1), (6,1, 5)
(4,0, 2), (4,0, 6), (4,1, 2), (4,1, 6), (6,0, 2), (6,0, 6), (6,1, 2), (6,1, 6)
(4,0, 3), (4,0, 7), (4,1, 3), (4,1, 7), (6,0, 3), (6,0, 7), (6,1, 3), (6,1, 7)
(5,0, 0), (5,0, 4), (5,1, 0), (5,1, 4), (7,0, 0), (7,0, 4), (7,1, 0), (7,1, 4)
(5,0, 1), (5,0, 5), (5,1, 1), (5,1, 5), (7,0, 1), (7,0, 5), (7,1, 1), (7,1, 5)
(5,0, 2), (5,0, 6), (5,1, 2), (5,1, 6), (7,0, 2), (7,0, 6), (7,1, 2), (7,1, 6)
(5,0, 3), (5,0, 7), (5,1, 3), (5,1, 7), (7,0, 3), (7,0, 7), (7,1, 3), (7,1, 7)
(4,0, 8), (4,0,12), (4,1, 8), (4,1,12), (6,0, 8), (6,0,12), (6,1, 8), (6,1,12)
(4,0, 9), (4,0,13), (4,1, 9), (4,1,13), (6,0, 9), (6,0,13), (6,1, 9), (6,1,13)
(4,0,10), (4,0,14), (4,1,10), (4,1,14), (6,0,10), (6,0,14), (6,1,10), (6,1,14)
(4,0,11), (4,0,15), (4,1,11), (4,1,15), (6,0,11), (6,0,15), (6,1,11), (6,1,15)
(5,0, 8), (5,0,12), (5,1, 8), (5,1,12), (7,0, 8), (7,0,12), (7,1, 8), (7,1,12)
(5,0, 9), (5,0,13), (5,1, 9), (5,1,13), (7,0, 9), (7,0,13), (7,1, 9), (7,1,13)
(5,0,10), (5,0,14), (5,1,10), (5,1,14), (7,0,10), (7,0,14), (7,1,10), (7,1,14)
(5,0,11), (5,0,15), (5,1,11), (5,1,15), (7,0,11), (7,0,15), (7,1,11), (7,1,15)
)";
  std::string layoutHW = getLayoutStr(tensorType, /*useHWPointOfView=*/true);
  assertSameStr(refHWRep, layoutHW);
}

} // anonymous namespace
} // namespace mlir::triton::gpu

// Test-binary entry point: installs LLVM's crash stack-trace printer,
// initialises GoogleTest from the command line, and returns the result of
// running every registered test.
int main(int argc, char *argv[]) {
  // argv[0] is handed to LLVM as the program name for its signal handler.
  llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);

  // InitGoogleTest takes argc by pointer so it can update it while parsing.
  testing::InitGoogleTest(&argc, argv);

  const int exitCode = RUN_ALL_TESTS();
  return exitCode;
}
